#  This file is part of the NssMPClib project.
#  Copyright (c) 2024 XDU NSS lab,
#  Licensed under the MIT license. See LICENSE in the project root for license information.

"""
Boolean secret-shared multiplication (AND) triples.

Comparison of the two ways of generating Boolean triples (homomorphic encryption and a trusted third party):

==================      ===============================         =================================================================================================================================
generation method       gen_msb_triples_by_ttp                  gen_msb_triples_by_homomorphic_encryption
security                Trust model based on TTP                Based on homomorphic encryption mechanism
key management          No key generation is involved           Involves the generation and management of keys
applicable scene        Suitable for scenarios that             Suitable for scenarios without a trusted third party
                        require trusted intermediaries
==================      ===============================         =================================================================================================================================
"""
import os
import random
import re
import secrets

from NssMPC.crypto.aux_parameter._parameter_base import Parameter
from NssMPC.common.ring import RingTensor
from NssMPC.config import param_path
from NssMPC.crypto.primitives.homomorphic_encryption.paillier import Paillier


class BooleanTriples(Parameter):
    """
    Parameter class holding one party's shares of a batch of Boolean (AND) triples.

    Once the shares of all parties are reconstructed, the triple ``(a, b, c)`` satisfies ``c = a & b``. Triples are
    produced either by a trusted third party (``'TTP'``) or interactively with homomorphic encryption (``'HE'``). They
    can be persisted with :meth:`gen_and_save`.
    """

    def __init__(self, a=None, b=None, c=None):
        """
        Initialize the triple shares.

        :param a: this party's share of the first random operand
        :param b: this party's share of the second random operand
        :param c: this party's share of the product ``a & b``
        """
        self.a = a
        self.b = b
        self.c = c
        # Number of triples in this batch (0 until a generator assigns it).
        self.size = 0

    def __iter__(self):
        """
        Iterate over the triple components.

        Allows unpacking such as ``a, b, c = triples``.

        :return: An iterator over ``(a, b, c)``.
        :rtype: iterator
        """
        return iter((self.a, self.b, self.c))

    @staticmethod
    def gen(num_of_triples, num_of_party=2, type_of_generation='TTP', party=None):
        """
        Generate ``num_of_triples`` Boolean triples with the requested method.

        Dispatches on ``type_of_generation``:
            * ``'HE'``: calls :func:`gen_msb_triples_by_homomorphic_encryption`, which runs interactively for
              ``party`` and returns that party's :class:`BooleanTriples`.
            * ``'TTP'``: calls :func:`gen_msb_triples_by_ttp`, which returns a list with one
              :class:`BooleanTriples` per party.

        :param num_of_triples: number of triples
        :type num_of_triples: int
        :param num_of_party: number of parties (used by TTP generation)
        :type num_of_party: int
        :param type_of_generation: generation type, ``'TTP'`` (trusted third party) or ``'HE'`` (homomorphic encryption)
        :type type_of_generation: str
        :param party: the local party; required when ``type_of_generation`` is ``'HE'``
        :type party: Party
        :returns: a list of per-party triples (TTP) or the local party's triples (HE)
        :raises ValueError: if ``type_of_generation`` is unknown, or ``party`` is missing for HE generation
        """
        if type_of_generation == 'HE':
            if party is None:
                raise ValueError("HE triple generation requires the local 'party'")
            return gen_msb_triples_by_homomorphic_encryption(num_of_triples, party)
        if type_of_generation == 'TTP':
            return gen_msb_triples_by_ttp(num_of_triples, num_of_party)
        raise ValueError(f"Unknown type_of_generation {type_of_generation!r}; expected 'HE' or 'TTP'")

    @classmethod
    def gen_and_save(cls, num, saved_name=None, num_of_party=2, saved_path=None, type_of_generation='TTP', party=None):
        """
        Generate Boolean triples and save them to disk.

        The triples are generated first, then the target directory is created if it does not exist.
            * 'TTP': one file ``{saved_name}_{i}.pth`` is written per party ``i``. The default directory is
              ``{param_path}BooleanTriples/``.
            * 'HE': only the local party's triples are written, as ``{saved_name}_{party_id}_{k}.pth``. ``k`` is one
              more than the largest index already present for this party, so existing files are never overwritten.
              The default directory is ``{param_path}BooleanTriples/{num_of_party}party/``.

        :param num: number of triples
        :type num: int
        :param saved_name: file-name prefix; defaults to ``'BooleanTriples'``
        :type saved_name: str
        :param num_of_party: number of parties
        :type num_of_party: int
        :param saved_path: directory for the generated params; a per-method default is used when ``None``
        :type saved_path: str
        :param type_of_generation: generation type. (TTP: generated by trusted third party, HE: generated by homomorphic encryption)
        :type type_of_generation: str
        :param party: party if HE
        :type party: Party
        """
        assert type_of_generation in ['HE', 'TTP']
        triples = cls.gen(num, num_of_party, type_of_generation, party)

        if saved_name is None:
            saved_name = 'BooleanTriples'

        if type_of_generation == 'TTP':
            file_path = f"{param_path}BooleanTriples/" if saved_path is None else saved_path
            os.makedirs(file_path, exist_ok=True)
            for i in range(num_of_party):
                file_name = f"{saved_name}_{i}.pth"
                triples[i].save(file_path=file_path, name=file_name)
            return

        # TODO: HE generation has not been fully adapted yet.
        if saved_path is None:
            file_path = f"{param_path}BooleanTriples/{num_of_party}party/"
        else:
            file_path = saved_path
        os.makedirs(file_path, exist_ok=True)

        # Match exactly "<saved_name>_<party_id>_<k>.pth"; escape the literal parts so ids/prefixes are not regex.
        pattern = re.compile(re.escape(f"{saved_name}_{party.party_id}_") + r"(\d+)\.pth")
        max_ptr = 0
        for fname in os.listdir(file_path):
            match = pattern.fullmatch(fname)
            if match:
                max_ptr = max(max_ptr, int(match.group(1)))

        file_name = f"{saved_name}_{party.party_id}_{max_ptr + 1}.pth"
        # Same save contract as the TTP branch: directory and file name passed separately.
        triples.save(file_path=file_path, name=file_name)


def gen_msb_triples_by_ttp(bit_len, num_of_party=2):
    """
    Generate Boolean (AND) triples for every party, acting as a trusted third party.

    The dealer samples two random binary tensors ``a`` and ``b`` of length ``bit_len`` on the CPU and computes
    ``c = a & b``. It then splits each of ``a``, ``b`` and ``c`` into ``num_of_party`` shares with Boolean secret
    sharing. Entry ``i`` of the result is a :class:`BooleanTriples` holding party ``i``'s shares, moved to the CPU,
    with ``size`` set to ``bit_len``.

    :param bit_len: the length of the binary string (number of triples)
    :type bit_len: int
    :param num_of_party: the number of parties
    :type num_of_party: int
    :returns: the list of msb triples generated by TTP, one per party.
    :rtype: list
    """
    a, b = (RingTensor.random([bit_len], down_bound=0, upper_bound=2, device='cpu') for _ in range(2))
    c = a & b

    # Local import, presumably to avoid an import cycle with the secret-sharing module -- TODO confirm.
    from NssMPC.crypto.primitives.boolean_secret_sharing.boolean_secret_sharing import BooleanSecretSharing
    a_shares, b_shares, c_shares = (BooleanSecretSharing.share(value, num_of_party) for value in (a, b, c))

    def _triple_for(pid):
        # Package party ``pid``'s three shares into one parameter object.
        triple = BooleanTriples(a_shares[pid].to('cpu'), b_shares[pid].to('cpu'), c_shares[pid].to('cpu'))
        triple.size = bit_len
        return triple

    return [_triple_for(pid) for pid in range(num_of_party)]


def gen_msb_triples_by_homomorphic_encryption(bit_len, party):
    """
    Generate Boolean (AND) triples interactively between two parties using Paillier homomorphic encryption.

    Each party samples its own random bit shares ``a_i`` and ``b_i``. The protocol produces integers ``c_0`` and
    ``c_1`` with ``c_0 + c_1 = (a_0 + a_1) * (b_0 + b_1)``. Reducing each of them modulo 2 yields Boolean shares with
    ``c_0 ^ c_1 = (a_0 ^ a_1) & (b_0 ^ b_1)``.

        If the current party's ``party_id`` is **0**:
            Generate a Paillier key pair, then encrypt ``a_0`` and ``b_0``. Send both ciphertexts together with the
            public key. Receive ``d`` from party 1, decrypt it and set ``c_0 = (Dec(d) + a_0 * b_0) mod 2``.

        If the current party's ``party_id`` is **1**:
            Sample a wide random mask ``r`` and set ``c_1 = (a_1 * b_1 - r) mod 2``. Receive party 0's message and
            compute ``d = Enc(a_0)^{b_1} * Enc(b_0)^{a_1} * Enc(r)``, i.e. ``Enc(a_0*b_1 + b_0*a_1 + r)``. This
            assumes the ciphertext ``**`` / ``*`` operators implement Paillier's homomorphic operations. Send ``d``
            back.

    :param bit_len: the length of the binary string (number of triples)
    :type bit_len: int
    :param party: the local party; must expose ``party_id`` (0 or 1), ``send`` and ``receive``
    :type party: Party
    :returns: This party's BooleanTriples (shares of a, b and c), with ``size`` set to ``bit_len``.
    :rtype: BooleanTriples
    :raises ValueError: if ``party.party_id`` is neither 0 nor 1
    """
    if party.party_id not in (0, 1):
        raise ValueError(f"HE triple generation supports only party ids 0 and 1, got {party.party_id!r}")

    # Local Boolean shares must be uniform bits and unpredictable to the peer, hence ``secrets``.
    a = [secrets.randbelow(2) for _ in range(bit_len)]
    b = [secrets.randbelow(2) for _ in range(bit_len)]

    if party.party_id == 0:
        paillier = Paillier()
        paillier.gen_keys()

        encrypted_a = paillier.encrypt(a)
        encrypted_b = paillier.encrypt(b)
        party.send([encrypted_a, encrypted_b, paillier.public_key])
        d = party.receive()
        decrypted_d = paillier.decrypt(d)
        # Reduce mod 2 to turn the integer additive share into a Boolean (XOR) share.
        c = [(decrypted_d[i] + a[i] * b[i]) % 2 for i in range(bit_len)]
    else:
        # Party 0 decrypts a_0*b_1 + b_0*a_1 + r with the cross terms in [0, 2]. A 1-bit r would leak party 1's
        # shares, so r is a wide statistical mask. Assumes the Paillier plaintext modulus is far larger than
        # 2**64 (so no wrap-around) -- TODO confirm against the key size.
        mask_bits = 64
        r = [secrets.randbits(mask_bits) for _ in range(bit_len)]
        c = [(ai * bi - ri) % 2 for ai, bi, ri in zip(a, b, r)]

        messages = party.receive()

        encrypted_r = Paillier.encrypt_with_key(r, messages[2])
        d = [messages[0][i] ** b[i] * messages[1][i] ** a[i] * encrypted_r[i] for i in range(bit_len)]
        party.send(d)

    msb_triples = BooleanTriples(RingTensor(a).to('cpu'), RingTensor(b).to('cpu'),
                                 RingTensor(c).to('cpu'))
    # Keep metadata consistent with the TTP generator.
    msb_triples.size = bit_len
    return msb_triples
